package drds.plus.sql_process.parser.visitor;

import drds.plus.parser.abstract_syntax_tree.expression.Expression;
import drds.plus.parser.abstract_syntax_tree.expression.Pair;
import drds.plus.parser.abstract_syntax_tree.expression.primary.misc.Identifier;
import drds.plus.parser.abstract_syntax_tree.expression.primary.misc.RowValues;
import drds.plus.parser.abstract_syntax_tree.statement.DeleteStatement;
import drds.plus.parser.abstract_syntax_tree.statement.InsertStatement;
import drds.plus.parser.abstract_syntax_tree.statement.Query;
import drds.plus.parser.abstract_syntax_tree.statement.ReplaceStatement;
import drds.plus.parser.abstract_syntax_tree.statement.select.table.Table;
import drds.plus.parser.visitor.VisitorImpl;

import java.util.List;
import java.util.Map;

/**
 * Renders a parser syntax tree directly back into SQL text, optionally replacing logical table
 * names with real (physical) table names.
 *
 * <p>All output is appended to the {@code sb} buffer inherited from {@link VisitorImpl}.
 */
public class SqlVisitor extends VisitorImpl {

    /**
     * Logical table name to real table name, used for table-name substitution. May be {@code null},
     * in which case no substitution happens. Per the original author, keys must be upper case.
     *
     * <p>NOTE(review): lookups use {@link Identifier#getText()} verbatim, with no case conversion.
     * Callers must ensure the identifier text matches the key casing.
     */
    private Map<String/* logic table */, String/* real table */> logicTableNameToRealTableNameMap;

    public SqlVisitor(StringBuilder sb, Map<String, String> logicTableNameToRealTableNameMap) {
        this(sb, null, logicTableNameToRealTableNameMap);
    }

    public SqlVisitor(StringBuilder sb, Object[] args, Map<String, String> logicTableNameToRealTableNameMap) {
        super(sb, args);
        this.logicTableNameToRealTableNameMap = logicTableNameToRealTableNameMap;
    }

    /**
     * Resolves the real table name for a logical table name.
     *
     * @param tableName logical table name, may be {@code null}
     * @return {@code null} if {@code tableName} is null; {@code tableName} itself when no map is
     *     configured; otherwise the mapped value, which is {@code null} when the map has no entry
     */
    private String getRealTableName(String tableName) {
        if (tableName == null) {
            return null;
        }

        String realTableName = tableName;
        if (logicTableNameToRealTableNameMap != null) {
            realTableName = logicTableNameToRealTableNameMap.get(tableName);
        }

        return realTableName;
    }

    /**
     * Prints a table identifier, substituting the real table name when one is resolved, and
     * otherwise printing the original identifier unchanged.
     *
     * <p>The substitute identifier is built with the original's parent. Presumably this is the
     * schema qualifier, so it still renders; verify against {@code Identifier}.
     */
    private void printTableName(Identifier tableName) {
        String realTableName = getRealTableName(tableName.getText());
        if (realTableName != null) {
            Identifier realTable = new Identifier(tableName.getParent(), realTableName);
            realTable.accept(this);
        } else {
            tableName.accept(this);
        }
    }

    /**
     * Prints {@code VALUES (..), (..)} for an INSERT or REPLACE statement.
     *
     * <p>Null or empty rows are skipped. If no row remains to be printed, an exception is thrown
     * instead of emitting a dangling {@code VALUES} clause, which would be invalid SQL.
     *
     * @param rowValuesList rows to print, may be {@code null}
     * @param statementKind statement name used in the error message ("insert" / "replace")
     * @throws IllegalArgumentException if there is no non-empty row to print
     */
    private void printRowValues(List<RowValues> rowValuesList, String statementKind) {
        sb.append("VALUES ");
        boolean first = true;
        if (rowValuesList != null) {
            for (RowValues rowValues : rowValuesList) {
                if (rowValues == null
                        || rowValues.getRowValueList() == null
                        || rowValues.getRowValueList().isEmpty()) {
                    continue;
                }
                if (first) {
                    first = false;
                } else {
                    sb.append(", ");
                }
                sb.append('(');
                printList(rowValues.getRowValueList());
                sb.append(')');
            }
        }
        if (first) {
            // Nothing was printed after "VALUES ", so the resulting SQL would be invalid.
            throw new IllegalArgumentException("at least one row for " + statementKind);
        }
    }

    /** Prints a table reference (with table-name substitution) followed by its optional alias. */
    public void visit(Table table) {
        printTableName(table.getTableName());
        String alias = table.getAlias();
        if (alias != null) {
            sb.append(" AS ").append(alias);
        }
    }

    /** Prints {@code DELETE FROM <table> [WHERE <condition>]}. */
    public void visit(DeleteStatement deleteStatement) {
        sb.append("DELETE FROM ");
        deleteStatement.getTable().accept(this);
        Expression where = deleteStatement.getWhere();
        if (where != null) {
            sb.append(" WHERE ");
            where.accept(this);
        }
    }

    /**
     * Prints an INSERT statement: the target table (with substitution), an optional column list,
     * either a VALUES list or a sub-query, and an optional ON DUPLICATE KEY UPDATE clause.
     *
     * @throws IllegalArgumentException if there is no sub-query and no non-empty row of values
     */
    public void visit(InsertStatement insertStatement) {
        sb.append("INSERT INTO ");
        printTableName(insertStatement.getTableName());
        sb.append(' ');

        List<Identifier> columnNameList = insertStatement.getColumnNameList();
        if (columnNameList != null && !columnNameList.isEmpty()) {
            sb.append('(');
            printList(columnNameList);
            sb.append(") ");
        }

        Query select = insertStatement.getQuery();
        if (select == null) {
            printRowValues(insertStatement.getRowValuesList(), "insert");
        } else {
            select.accept(this);
        }

        List<Pair<Identifier, Expression>> pairList = insertStatement.getDuplicateUpdate();
        if (pairList != null && !pairList.isEmpty()) {
            sb.append(" ON DUPLICATE KEY UPDATE ");
            boolean first = true;
            for (Pair<Identifier, Expression> pair : pairList) {
                if (first) {
                    first = false;
                } else {
                    sb.append(", ");
                }
                pair.getKey().accept(this);
                sb.append(" = ");
                pair.getValue().accept(this);
            }
        }
    }

    /**
     * Prints a REPLACE statement: the target table (with substitution), an optional column list,
     * and either a VALUES list or a sub-query.
     *
     * @throws IllegalArgumentException if there is no sub-query and no non-empty row of values
     */
    public void visit(ReplaceStatement replaceStatement) {
        sb.append("REPLACE INTO ");
        printTableName(replaceStatement.getTableName());
        sb.append(' ');

        List<Identifier> columnNameList = replaceStatement.getColumnNameList();
        if (columnNameList != null && !columnNameList.isEmpty()) {
            sb.append('(');
            printList(columnNameList);
            sb.append(") ");
        }

        Query select = replaceStatement.getQuery();
        if (select == null) {
            printRowValues(replaceStatement.getRowValuesList(), "replace");
        } else {
            select.accept(this);
        }
    }

    /** Replaces the logical-to-real table name map; {@code null} disables substitution. */
    public void setLogicTableNameToRealTableNameMap(Map<String, String> logicTableNameToRealTableNameMap) {
        this.logicTableNameToRealTableNameMap = logicTableNameToRealTableNameMap;
    }

}
